#  Copyright (c) 2022, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import itertools
import os
import tempfile

import numpy as np
import pytest

import coremltools as ct
from coremltools import proto
from coremltools._deps import _HAS_TF_1, _HAS_TF_2, MSG_TF1_NOT_FOUND
from coremltools.converters.mil.testing_reqs import backends, compute_units
from coremltools.converters.mil.testing_utils import (
    assert_cast_ops_count,
    assert_input_dtype,
    assert_ops_in_mil_program,
    assert_output_dtype,
    assert_prog_input_type,
    assert_prog_output_type,
    assert_spec_input_image_type,
    assert_spec_output_image_type,
    get_op_types_in_program,
    verify_prediction,
)
from coremltools.test.api.test_api_examples import TestInputs as _TestInputs

tf = pytest.importorskip("tensorflow")

#################################################################################
# Note: all tests are also used as examples in https://coremltools.readme.io/docs
# as a reference.
# Whenever any of the following tests fails, we should update the API documentation
#################################################################################


@pytest.mark.skipif(not _HAS_TF_1, reason=MSG_TF1_NOT_FOUND)
@pytest.mark.skipif(ct.utils._macos_version() < (10, 15), reason='Model produces specification 4.')
class TestTensorFlow1ConverterExamples:
    """Example-style tests for converting TensorFlow 1 models with ``ct.convert``.

    Each test builds a tiny TF1 graph and converts it from a different source:
    an in-memory graph, a frozen ``.pb`` file, a SavedModel directory, or a
    graph frozen from a checkpoint. Where predictions are made, they are
    compared with the values computed by a TensorFlow session.
    """

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_convert_from_frozen_graph(tmpdir, backend):
        """Convert an in-memory ``tf.Graph`` and compare its prediction with TensorFlow."""
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input")
            y = tf.nn.relu(x, name="output")

        mlmodel = ct.convert(graph, convert_to=backend[0], compute_units=ct.ComputeUnit.CPU_ONLY)

        # -0.5 to have some negative values
        test_input = np.random.rand(1, 2, 3) - 0.5
        with tf.compat.v1.Session(graph=graph) as sess:
            expected_val = sess.run(y, feed_dict={x: test_input})
        results = mlmodel.predict({"input": test_input})
        np.testing.assert_allclose(results["output"], expected_val)

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_convert_from_frozen_graph_file(tmpdir, backend):
        """Convert a frozen graph ``.pb`` file using three ways of specifying inputs."""
        # create the model to convert

        # write a toy frozen graph
        # Note that we usually need to run freeze_graph() on tf.Graph()
        # skipping here as this toy model does not contain any variables
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input")
            y = tf.nn.relu(x, name="output")

        save_path = str(tmpdir)
        tf.io.write_graph(graph, save_path, "frozen_graph.pb", as_text=False)

        # Create a test sample
        # -0.5 to have some negative values
        test_input = np.random.rand(1, 2, 3) - 0.5
        with tf.compat.v1.Session(graph=graph) as sess:
            expected_val = sess.run(y, feed_dict={x: test_input})

        # The input `.pb` file is in the frozen graph format that is usually
        # generated by TensorFlow's utility function `freeze_graph()`
        pb_path = os.path.join(save_path, "frozen_graph.pb")

        # 3 ways to specify inputs:
        # (1) Fully specify inputs
        mlmodel = ct.convert(
            pb_path,
            # We specify inputs with name matching the placeholder name.
            inputs=[ct.TensorType(name="input", shape=(1, 2, 3))],
            outputs=["output"],
            convert_to=backend[0],
        )

        # (2) Specify input TensorType without name (when there's only one
        # input)
        mlmodel = ct.convert(
            pb_path,
            # TensorType name is optional when there's only one input.
            inputs=[ct.TensorType(shape=(1, 2, 3))],
            outputs=["output"],
            convert_to=backend[0],
        )

        # (3) Not specify inputs at all. `inputs` is optional for TF. When
        # inputs is not specified, convert() infers inputs from Placeholder
        # nodes.
        mlmodel = ct.convert(
            pb_path,
            outputs=["output"],
            convert_to=backend[0],
            compute_units=ct.ComputeUnit.CPU_ONLY,
        )

        # Only the model from (3) is used for predictions below.
        results = mlmodel.predict({"input": test_input})
        np.testing.assert_allclose(results["output"], expected_val)
        suffix = ".mlmodel" if backend[0] == "neuralnetwork" else ".mlpackage"
        mlmodel_path = os.path.join(save_path, "model" + suffix)
        # Save the converted model
        mlmodel.save(mlmodel_path)

        # Predict again with the same in-memory model after saving it.
        results = mlmodel.predict({"input": test_input})
        np.testing.assert_allclose(results["output"], expected_val, atol=1e-3)

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_convert_from_saved_model_dir(tmpdir, backend):
        """Convert a TF1 SavedModel directory without passing inputs or outputs."""
        # Sample input
        test_input = np.random.rand(1, 3, 5) - 0.5

        # create the model to convert
        with tf.compat.v1.Session() as sess:
            x = tf.placeholder(shape=(1, 3, 5), dtype=tf.float32)
            y = tf.nn.relu(x)

            expected_val = sess.run(y, feed_dict={x: test_input})

        # Save model as SavedModel
        inputs = {"x": x}
        outputs = {"y": y}
        save_path = str(tmpdir)
        tf.compat.v1.saved_model.simple_save(sess, save_path, inputs, outputs)

        # SavedModel directory generated by TensorFlow 1.x
        # when converting from SavedModel dir, inputs / outputs are optional
        mlmodel = ct.convert(
            save_path, convert_to=backend[0], compute_units=ct.ComputeUnit.CPU_ONLY
        )

        # Need input output names to call mlmodel
        # x.name == 'Placeholder:0'. Strip out ':0'
        input_name = x.name.split(":")[0]
        results = mlmodel.predict({input_name: test_input})
        # y.name == 'Relu:0'. output_name == 'Relu'
        output_name = y.name.split(":")[0]
        np.testing.assert_allclose(results[output_name], expected_val)

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_freeze_and_convert_matmul_graph(backend):
        """Freeze a graph with variables into a single ``.pb`` file and convert it."""
        # testing : https://coremltools.readme.io/docs/tensorflow-1#export-as-frozen-graph-and-convert
        graph = tf.Graph()
        with graph.as_default():
            x = tf.placeholder(tf.float32, shape=[None, 20], name="input")
            W = tf.Variable(tf.truncated_normal([20, 10], stddev=0.1))
            b = tf.Variable(tf.ones([10]))
            y = tf.matmul(x, W) + b
            output_names = [y.op.name]

        from tensorflow.python.tools.freeze_graph import freeze_graph

        # The context manager removes the temporary directory when the test
        # ends, including when an exception is raised part-way through.
        with tempfile.TemporaryDirectory() as model_dir:
            graph_def_file = os.path.join(model_dir, "tf_graph.pb")
            checkpoint_file = os.path.join(model_dir, "tf_model.ckpt")
            frozen_graph_file = os.path.join(model_dir, "tf_frozen.pb")

            with tf.Session(graph=graph) as sess:
                # initialize variables
                sess.run(tf.global_variables_initializer())
                # save graph definition somewhere
                tf.train.write_graph(
                    sess.graph, model_dir, graph_def_file, as_text=False
                )
                # save the weights
                saver = tf.train.Saver()
                saver.save(sess, checkpoint_file)

                # take the graph definition and weights
                # and freeze into a single .pb frozen graph file
                freeze_graph(input_graph=graph_def_file,
                             input_saver="",
                             input_binary=True,
                             input_checkpoint=checkpoint_file,
                             output_node_names=",".join(output_names),
                             restore_op_name="save/restore_all",
                             filename_tensor_name="save/Const:0",
                             output_graph=frozen_graph_file,
                             clear_devices=True,
                             initializer_nodes="")
            print("Tensorflow frozen graph saved at {}".format(frozen_graph_file))
            # Convert while the frozen graph file still exists on disk.
            ct.convert(frozen_graph_file, convert_to=backend[0])

    @staticmethod
    def test_convert_tf1_frozen_graph_to_milinternal(tmpdir):
        """Converting with ``convert_to='milinternal'`` returns a MIL ``Program``."""
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input")
            y = tf.nn.relu(x, name="output")

        model = ct.convert(graph, convert_to='milinternal')
        assert isinstance(model, ct.converters.mil.Program)

    @staticmethod
    def test_mil_op_names_consistency(tmpdir):
        '''
        Test to make sure that when the same model is converted to MIL program,
        in the same session, it gives the same program, with the same op names
        '''
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(1, 5, 5, 3), name="input")
            conv = tf.nn.conv2d(
                x,
                filter=tf.constant(np.random.rand(1, 1, 3, 5), tf.float32),
                padding="VALID",
            )
            y = tf.nn.relu(conv, name="output")

        mil_prog1 = ct.convert(graph, convert_to='milinternal')
        # convert the same model again
        mil_prog2 = ct.convert(graph, convert_to='milinternal')

        # NOTE(review): the test name and docstring talk about op *names*, but
        # the helper used here is `get_op_types_in_program` -- confirm that this
        # is the intended comparison.
        np.testing.assert_array_equal(
            get_op_types_in_program(mil_prog1),
            get_op_types_in_program(mil_prog2),
        )

###############################################################################
# Note: Stress tests for TF1 input / output types
###############################################################################
@pytest.mark.skipif(ct.utils._macos_version() < (10, 15), reason='Model produces specification 4.')
@pytest.mark.skipif(not _HAS_TF_1, reason=MSG_TF1_NOT_FOUND)
class TestTf1Inputs(_TestInputs):
    """Stress tests for how ``ct.convert`` handles the inputs of TF1 graphs.

    ``test_tf_predict_input`` calls ``_test_variant_input_type_prediction``,
    which is not defined in this class (presumably inherited from the shared
    ``TestInputs`` base -- confirm there).
    """

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_input_noname(backend):
        """An unnamed ``TensorType`` for a graph with two placeholders raises ``ValueError``."""
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input")
            x1 = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input_1")
            y = tf.nn.relu(x, name="output")
            y1 = tf.nn.relu(x1, name="output_1")

        with pytest.raises(ValueError) as e:
            ct.convert(
                graph,
                inputs=[ct.TensorType(shape=(1, 2, 3))],
                convert_to=backend[0],
            )
        expected_error = "Multiple inputs are found in graph, but no input name was provided"
        assert expected_error == str(e.value)

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    def test_input_wrongname(backend):
        """A ``TensorType`` whose name matches no placeholder raises ``ValueError``."""
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input")
            x1 = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input_1")
            y = tf.nn.relu(x, name="output")
            y1 = tf.nn.relu(x1, name="output_1")

        with pytest.raises(ValueError) as e:
            ct.convert(
                graph,
                inputs=[ct.TensorType(shape=(1, 2, 3), name="wrong_input")],
                convert_to=backend[0],
            )
        expected_error = "Input ({}) provided is not found in given tensorflow graph. Placeholders in graph are: {}".format("wrong_input", ["input", "input_1"])
        assert expected_error == str(e.value)

    @pytest.mark.parametrize(
        "backend, compute_unit",
        itertools.product(
            backends,
            compute_units,
        ),
    )
    def test_input_dynamic_without_inputs_param(self, backend, compute_unit):
        """The `inputs` param is not provided for a dynamic input (shape has `None`)."""
        with tf.Graph().as_default() as graph:
            x = tf.placeholder(tf.float32, shape=(None, None, 3), name="input")
            x1 = tf.placeholder(tf.float32, shape=(1, 2, 3), name="input_1")
            y = tf.nn.relu(x, name="output")
            y1 = tf.nn.relu(x1, name="output_1")

        convert_to = backend[0]
        if convert_to == "mlprogram":
            # Only the mlprogram backend is expected to warn about the
            # unknown dimensions being given bounded flexible ranges.
            with pytest.warns(
                UserWarning,
                match="Some dimensions in the input shape are unknown, hence they are set to "
                "flexible ranges with lower bound and default value = 1, and upper bound = 2. "
            ):
                mlmodel = ct.convert(
                    graph,
                    convert_to=convert_to,
                    compute_units=compute_unit,
                )
        else:
            mlmodel = ct.convert(
                graph,
                convert_to=convert_to,
                compute_units=compute_unit,
            )

        spec = mlmodel.get_spec()
        multi_array_type = spec.description.input[0].type.multiArrayType
        assert list(multi_array_type.shape) == [1, 1, 3]

        size_range = multi_array_type.shapeRange.sizeRanges[1]
        assert size_range.lowerBound == 1
        # Choose the expected bound first, then compare. Writing
        # `assert x == -1 if cond else 2` parses as `(x == -1) if cond else 2`,
        # which asserts the truthy constant 2 for every non-neuralnetwork backend.
        expected_upper_bound = -1 if convert_to == "neuralnetwork" else 2
        assert size_range.upperBound == expected_upper_bound

    @staticmethod
    @pytest.mark.parametrize(
        "backend",
        backends,
    )
    @pytest.mark.skipif(not ct.utils._is_macos(), reason="test needs predictions")
    def test_tf_predict_input(backend):
        """Run the shared variant-input prediction check with ``tf.convert_to_tensor``."""
        TestTf1Inputs._test_variant_input_type_prediction(tf.convert_to_tensor, backend[0])

@pytest.fixture
def uint8_input_model():
    """TF1 graph adding the uint8 constant 5 to a uint8 [10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.uint8, shape=[10, 20], name="input")
        five = tf.constant(5, dtype=tf.uint8)
        tf.add(placeholder, five, name="output")
    return graph


@pytest.fixture
def int8_input_model():
    """TF1 graph adding the int8 constant 5 to an int8 [10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.int8, shape=[10, 20], name="input")
        five = tf.constant(5, dtype=tf.int8)
        tf.add(placeholder, five, name="output")
    return graph


@pytest.fixture
def int32_input_model():
    """TF1 graph adding the int32 constant 5 to an int32 [10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.int32, shape=[10, 20], name="input")
        five = tf.constant(5, dtype=tf.int32)
        tf.add(placeholder, five, name="output")
    return graph


@pytest.fixture
def int32_two_input_model():
    """TF1 graph summing two int32 [10, 20] placeholders "input1" and "input2" into "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        first = tf.placeholder(tf.int32, shape=[10, 20], name="input1")
        second = tf.placeholder(tf.int32, shape=[10, 20], name="input2")
        tf.add(first, second, name="output")
    return graph


@pytest.fixture
def int32_two_output_model():
    """TF1 graph with two int32 [10, 20] placeholders, each incremented by 1 into "output1" / "output2"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        first = tf.placeholder(tf.int32, shape=[10, 20], name="input1")
        second = tf.placeholder(tf.int32, shape=[10, 20], name="input2")
        tf.add(first, 1, name="output1")
        tf.add(second, 1, name="output2")
    return graph


@pytest.fixture
def int32_float32_two_output_model():
    """TF1 graph with two float32 [10, 20] placeholders and mixed-dtype outputs.

    "output1" is ``input1 + 1.0``; "output2" is ``input2 + 1.0`` cast to int32.
    """
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        first = tf.placeholder(tf.float32, shape=[10, 20], name="input1")
        second = tf.placeholder(tf.float32, shape=[10, 20], name="input2")
        tf.add(first, 1.0, name="output1")
        second_plus_one = tf.add(second, 1.0)
        tf.cast(second_plus_one, dtype=tf.int32, name="output2")
    return graph


@pytest.fixture
def int32_float32_two_input_model():
    """TF1 graph adding an int32 placeholder "input1" (cast to float32) to a float32 placeholder "input2"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        int_input = tf.placeholder(tf.int32, shape=[10, 20], name="input1")
        float_input = tf.placeholder(tf.float32, shape=[10, 20], name="input2")
        int_as_float = tf.cast(int_input, dtype=tf.float32)
        tf.add(int_as_float, float_input, name="output")
    return graph

@pytest.fixture
def float32_input_model_add_op():
    """TF1 graph adding the float32 constant 5.5 to a float32 [10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[10, 20], name="input")
        offset = tf.constant(5.5, dtype=tf.float32)
        tf.add(placeholder, offset, name="output")
    return graph

@pytest.fixture
def float32_input_model_relu_ops():
    """TF1 graph applying two consecutive relu ops to a float32 [10, 20] placeholder "input"; last op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[10, 20], name="input")
        first_relu = tf.nn.relu(placeholder)
        tf.nn.relu(first_relu, name="output")
    return graph

@pytest.fixture
def int64_input_model():
    """TF1 graph adding the int64 constant 5 to an int64 [10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.int64, shape=[10, 20], name="input")
        five = tf.constant(5, dtype=tf.int64)
        tf.add(placeholder, five, name="output")
    return graph

@pytest.fixture
def float32_two_input_model():
    """TF1 graph summing two float32 [10, 20] placeholders "input1" and "input2" into "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        first = tf.placeholder(tf.float32, shape=[10, 20], name="input1")
        second = tf.placeholder(tf.float32, shape=[10, 20], name="input2")
        tf.add(first, second, name="output")
    return graph

@pytest.fixture
def float32_two_output_model():
    """TF1 graph with one float32 [10, 20] placeholder "input" and two outputs.

    "output2" is relu6 of the input; "output1" is relu applied twice to the input.
    """
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[10, 20], name="input")
        # Ops are created in the same order as before so that the
        # auto-generated name of the unnamed relu is unchanged.
        inner_relu = tf.nn.relu(placeholder)
        tf.nn.relu6(placeholder, name="output2")
        tf.nn.relu(inner_relu, name="output1")
    return graph

@pytest.fixture
def float64_input_model():
    """TF1 graph adding the float64 constant 5 to a float64 [10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float64, shape=[10, 20], name="input")
        five = tf.constant(5, dtype=tf.float64)
        tf.add(placeholder, five, name="output")
    return graph


@pytest.fixture
def rank3_input_model():
    """TF1 graph adding the float32 constant 5 to a float32 [1, 10, 20] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[1, 10, 20], name="input")
        five = tf.constant(5, dtype=tf.float32)
        tf.add(placeholder, five, name="output")
    return graph

@pytest.fixture
def rank4_input_model():
    """TF1 graph adding the float32 constant 5 to a float32 [1, 10, 20, 3] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[1, 10, 20, 3], name="input")
        five = tf.constant(5, dtype=tf.float32)
        tf.add(placeholder, five, name="output")
    return graph

@pytest.fixture
def rank4_input_model_with_channel_first_output():
    """TF1 graph: float32 [1, 10, 20, 3] placeholder "input" plus 5, transposed with perm [0, 3, 1, 2] into "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[1, 10, 20, 3], name="input")
        shifted = tf.add(placeholder, tf.constant(5, dtype=tf.float32))
        tf.transpose(shifted, perm=[0, 3, 1, 2], name="output")
    return graph

@pytest.fixture
def rank4_grayscale_input_model():
    """TF1 graph adding the float32 constant 5 to a float32 [1, 10, 20, 1] placeholder "input"; result op is "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[1, 10, 20, 1], name="input")
        five = tf.constant(5, dtype=tf.float32)
        tf.add(placeholder, five, name="output")
    return graph

@pytest.fixture
def rank4_grayscale_input_model_with_channel_first_output():
    """TF1 graph: float32 [1, 10, 20, 1] placeholder "input" plus 5, transposed with perm [0, 3, 1, 2] into "output"."""
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[1, 10, 20, 1], name="input")
        shifted = tf.add(placeholder, tf.constant(5, dtype=tf.float32))
        tf.transpose(shifted, perm=[0, 3, 1, 2], name="output")
    return graph

@pytest.fixture
def linear_model():
    """TF1 graph: float32 [1, 2] placeholder "input" -> matmul with a (2, 4) constant -> add a (4,) constant -> relu.

    The final relu op is created without an explicit name.
    """
    if not _HAS_TF_1:
        pytest.skip(MSG_TF1_NOT_FOUND)
    # this model will test the fuse_matmul_weight_bias pass
    graph = tf.Graph()
    with graph.as_default():
        placeholder = tf.placeholder(tf.float32, shape=[1, 2], name="input")
        weight = tf.constant([1, 2], shape=(2, 4), dtype=tf.float32)
        product = tf.matmul(placeholder, weight)
        bias = tf.constant([1, 2, 3, 4], shape=(4,), dtype=tf.float32)
        biased = tf.add(product, bias)
        tf.nn.relu(biased)
    return graph


@pytest.mark.skipif(ct.utils._macos_version() < (13, 0), reason='Tests are for deployment target ios16/macos13')
class TestInputOutputConversionAPI:

    def test_input_dtype_inferred(self, int32_input_model):
        """Without an ``inputs`` argument, the converted model's input dtype is picked up from TF (int32)."""
        converted = ct.convert(
            int32_input_model,
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(converted, expected_type_str="int32")
        verify_prediction(converted)

    def test_unsupported_input_dtype_in_tf_graph_uint8(self, uint8_input_model):
        """A uint8 TF input with no user-provided dtype converts without error to an int32 input."""
        # The TF graph's input dtype is not supported, so it is expected to be
        # mapped to the closest supported dtype instead of raising.
        converted = ct.convert(
            uint8_input_model,
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(converted, expected_type_str="int32")
        verify_prediction(converted)

    def test_unsupported_input_dtype_in_tf_graph_int8(self, int8_input_model):
        """An int8 TF input with no user-provided dtype converts without error to an int32 input."""
        # The TF graph's input dtype is not supported, so it is expected to be
        # mapped to the closest supported dtype instead of raising.
        converted = ct.convert(
            int8_input_model,
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(converted, expected_type_str="int32")
        verify_prediction(converted)

    def test_unsupported_input_dtype_in_tf_graph_int64(self, int64_input_model):
        """An int64 TF input with no user-provided dtype converts without error to an int32 input."""
        # The TF graph's input dtype is not supported, so it is expected to be
        # mapped to the closest supported dtype instead of raising.
        converted = ct.convert(
            int64_input_model,
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(converted, expected_type_str="int32")
        verify_prediction(converted)

    def test_unsupported_input_dtype_in_tf_graph_fp64(self, float64_input_model):
        """A float64 TF input with no user-provided dtype converts without error to an fp32 input."""
        # The TF graph's input dtype is not supported, so it is expected to be
        # mapped to the closest supported dtype instead of raising.
        converted = ct.convert(
            float64_input_model,
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(converted, expected_type_str="fp32")
        verify_prediction(converted)

    def test_input_dtype_user_provided(self, int32_input_model):
        """A dtype given through ``inputs`` overrides the input dtype of the TF model."""
        requested_inputs = [ct.TensorType(dtype=np.float32)]
        converted = ct.convert(
            int32_input_model,
            inputs=requested_inputs,
            minimum_deployment_target=ct.target.macOS12,
        )
        # Both the input and the output are expected to be fp32.
        assert_input_dtype(converted, expected_type_str="fp32")
        assert_output_dtype(converted, expected_type_str="fp32")
        verify_prediction(converted)

    def test_invalid_input_dtype(self, int32_input_model):
        """Unsupported user-provided input dtypes make ``ct.convert`` raise ``TypeError``."""
        # Each entry pairs a rejected dtype with the error message expected
        # for it at deployment target macOS12.
        rejected_cases = (
            (np.int16, "is unsupported for inputs/outputs of the model"),
            (
                np.float16,
                "float16 dtype for inputs is only supported for deployment target >= iOS16/macOS13",
            ),
        )
        for bad_dtype, expected_message in rejected_cases:
            with pytest.raises(TypeError, match=expected_message):
                ct.convert(
                    int32_input_model,
                    inputs=[ct.TensorType(dtype=bad_dtype)],
                    minimum_deployment_target=ct.target.macOS12,
                )

    def test_fp16_input_dtype(self, float32_input_model_add_op, float32_input_model_relu_ops, int32_input_model):
        """
        Check that an fp16 input dtype with an fp32 output dtype works at
        deployment target macOS13, for three different source graphs.
        """
        cases = (
            (float32_input_model_add_op, ["add", "cast"]),
            # Two consecutive relus are merged in the `merge_consecutive_relus` pass.
            (float32_input_model_relu_ops, ["relu", "cast"]),
            (int32_input_model, ["add", "cast"]),
        )
        for tf_graph, expected_ops in cases:
            # Fresh TensorType objects are created for every conversion.
            converted = ct.convert(
                tf_graph,
                inputs=[ct.TensorType(dtype=np.float16)],
                outputs=[ct.TensorType(dtype=np.float32)],
                minimum_deployment_target=ct.target.macOS13,
            )
            assert_ops_in_mil_program(converted, expected_op_list=expected_ops)
            assert_input_dtype(converted, expected_type_str="fp16")
            assert_output_dtype(converted, expected_type_str="fp32")
            verify_prediction(converted)

    def test_fp16_input_dtype_fp32_precision(self, float32_input_model_add_op, float32_input_model_relu_ops,
                                             int32_input_model):
        """
        Variant of test_fp16_input_dtype that converts with Float32 compute precision.

        Only the add and relu graphs are converted here; the
        ``int32_input_model`` fixture is requested but not used.
        """

        def convert_fp16_in_fp32_out(tf_graph):
            # Build new TensorType objects for every call.
            return ct.convert(
                tf_graph,
                inputs=[ct.TensorType(dtype=np.float16)],
                outputs=[ct.TensorType(dtype=np.float32)],
                minimum_deployment_target=ct.target.macOS13,
                compute_precision=ct.precision.FLOAT32,
            )

        add_model = convert_fp16_in_fp32_out(float32_input_model_add_op)
        assert_ops_in_mil_program(add_model, expected_op_list=["cast", "add"])
        assert_input_dtype(add_model, expected_type_str="fp16")
        assert_output_dtype(add_model, expected_type_str="fp32")
        verify_prediction(add_model)

        # No prediction is run for the relu model in this test.
        relu_model = convert_fp16_in_fp32_out(float32_input_model_relu_ops)
        assert_ops_in_mil_program(relu_model, expected_op_list=["cast", "relu"])
        assert_input_dtype(relu_model, expected_type_str="fp16")
        assert_output_dtype(relu_model, expected_type_str="fp32")

    def test_two_input_model(self, float32_two_input_model):
        """Exercise user-provided input dtypes on a graph with two float32 inputs."""
        # Case 1: only "input1" is forced to int32; "input2" stays fp32.
        one_forced = ct.convert(
            float32_two_input_model,
            inputs=[ct.TensorType(name="input1", dtype=np.int32)],
            outputs=[ct.TensorType(dtype=np.float32)],
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(one_forced, expected_type_str="int32", expected_name="input1")
        assert_input_dtype(one_forced, expected_type_str="fp32", expected_name="input2")
        assert_output_dtype(one_forced, expected_type_str="fp32")

        # Case 2: both inputs are forced to int32; the output is then int32 too.
        both_int32 = ct.convert(
            float32_two_input_model,
            inputs=[
                ct.TensorType(name="input1", dtype=np.int32),
                ct.TensorType(name="input2", dtype=np.int32),
            ],
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_input_dtype(both_int32, expected_type_str="int32", expected_name="input1")
        assert_input_dtype(both_int32, expected_type_str="int32", expected_name="input2")
        assert_output_dtype(both_int32, expected_type_str="int32")

        # Case 3: leaving out the names must raise an error.
        with pytest.raises(ValueError):
            ct.convert(
                float32_two_input_model,
                inputs=[
                    ct.TensorType(dtype=np.int32),
                    ct.TensorType(dtype=np.int32),
                ],
                minimum_deployment_target=ct.target.macOS12,
            )

        # Case 4: both inputs are forced to float16 with an fp32 output.
        both_fp16 = ct.convert(
            float32_two_input_model,
            inputs=[
                ct.TensorType(name="input1", dtype=np.float16),
                ct.TensorType(name="input2", dtype=np.float16),
            ],
            outputs=[ct.TensorType(dtype=np.float32)],
            minimum_deployment_target=ct.target.macOS13,
        )
        assert_input_dtype(both_fp16, expected_type_str="fp16", expected_name="input1")
        assert_input_dtype(both_fp16, expected_type_str="fp16", expected_name="input2")
        assert_output_dtype(both_fp16, expected_type_str="fp32")
        assert_cast_ops_count(both_fp16, expected_count=1)
        verify_prediction(both_fp16)

    def test_single_output_model(self, int32_input_model, float32_input_model_relu_ops):
        """Exercise the ``outputs`` argument of ``ct.convert`` on single-output graphs."""
        # Name under which the single output is looked up in the converted model.
        output_name = "Identity" if _HAS_TF_2 else "output"

        # Default conversion: the output dtype follows the graph (int32).
        default_model = ct.convert(
            int32_input_model,
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_ops_in_mil_program(default_model, expected_op_list=["add"])
        assert_output_dtype(default_model, expected_type_str="int32")

        # An output name that does not exist in the model must raise.
        with pytest.raises(Exception):
            ct.convert(
                int32_input_model,
                outputs=["z"],
                minimum_deployment_target=ct.target.macOS12,
            )

        # Two outputs provided without names must raise.
        with pytest.raises(ValueError, match=", does not have names"):
            ct.convert(
                int32_input_model,
                outputs=[ct.TensorType(dtype=np.float32), ct.TensorType(dtype=np.float32)],
                minimum_deployment_target=ct.target.macOS12,
            )

        # Providing a shape for the output must raise.
        with pytest.raises(ValueError):
            ct.convert(
                int32_input_model,
                outputs=[ct.TensorType(dtype=np.float32, shape=(10, 20))],
                minimum_deployment_target=ct.target.macOS12,
            )

        # A user-provided output dtype is applied during conversion.
        fp32_out_model = ct.convert(
            int32_input_model,
            outputs=[ct.TensorType(dtype=np.float32)],
            minimum_deployment_target=ct.target.macOS12,
        )
        assert_output_dtype(fp32_out_model, expected_type_str="fp32", expected_name=output_name)
        assert_ops_in_mil_program(fp32_out_model, expected_op_list=["add", "cast"])

        # A float16 output dtype is rejected when the deployment target is too low.
        with pytest.raises(
            TypeError,
            match="float16 dtype for outputs is only supported for deployment target >= iOS16/macOS13",
        ):
            ct.convert(
                float32_input_model_relu_ops,
                outputs=[ct.TensorType(dtype=np.float16)],
                minimum_deployment_target=ct.target.macOS12,
            )

        # A float16 output dtype with a float32 input is applied correctly.
        fp16_out_model = ct.convert(
            float32_input_model_relu_ops,
            inputs=[ct.TensorType(name="input", dtype=np.float32)],
            outputs=[ct.TensorType(dtype=np.float16)],
            minimum_deployment_target=ct.target.macOS13,
        )
        assert_output_dtype(fp16_out_model, expected_type_str="fp16", expected_name=output_name)
        assert_ops_in_mil_program(fp16_out_model, expected_op_list=["cast", "relu"])

        # float16 for both the input and the output is applied correctly.
        fp16_io_model = ct.convert(
            float32_input_model_relu_ops,
            inputs=[ct.TensorType(dtype=np.float16)],
            outputs=[ct.TensorType(dtype=np.float16)],
            minimum_deployment_target=ct.target.macOS13,
        )
        assert_input_dtype(fp16_io_model, expected_type_str="fp16")
        assert_output_dtype(fp16_io_model, expected_type_str="fp16", expected_name=output_name)
        assert_ops_in_mil_program(fp16_io_model, expected_op_list=["relu"])
        verify_prediction(fp16_io_model)

    def test_multi_output_model(self, float32_two_output_model):
        """
        Check output-name validation and per-output dtype selection on the
        two-output fixture.
        """
        target = ct.target.macOS13
        missing_names_msg = "please provide names for each of the outputs"

        # Output specs without names are rejected, both when a single spec is
        # given and when one spec per output is given.
        for unnamed_dtypes in ([np.float16], [np.float16, np.float32]):
            with pytest.raises(ValueError, match=missing_names_msg):
                ct.convert(
                    float32_two_output_model,
                    outputs=[ct.TensorType(dtype=dtype) for dtype in unnamed_dtypes],
                    minimum_deployment_target=target,
                )

        # Output names differ between the TF2 and TF1 variants of the fixture.
        if _HAS_TF_2:
            first_name, second_name = "Identity", "Identity_1"
        else:
            first_name, second_name = "output1", "output2"

        # Request both outputs (second one listed first): one as fp16, the other as fp32.
        both_outputs = ct.convert(
            float32_two_output_model,
            inputs=[ct.TensorType(dtype=np.float16)],
            outputs=[
                ct.TensorType(name=second_name, dtype=np.float16),
                ct.TensorType(name=first_name, dtype=np.float32),
            ],
            minimum_deployment_target=target,
        )
        assert_cast_ops_count(both_outputs, expected_count=1)
        assert_output_dtype(
            both_outputs, expected_type_str="fp16", expected_name=second_name, index=0
        )
        assert_output_dtype(
            both_outputs, expected_type_str="fp32", expected_name=first_name, index=1
        )
        assert_input_dtype(both_outputs, expected_type_str="fp16")
        verify_prediction(both_outputs)

        # Naming a single output selects only that output.
        single_output = ct.convert(
            float32_two_output_model,
            inputs=[ct.TensorType(dtype=np.float16)],
            outputs=[ct.TensorType(name=second_name, dtype=np.float16)],
            minimum_deployment_target=target,
        )
        assert_cast_ops_count(single_output, expected_count=0)
        assert_output_dtype(
            single_output, expected_type_str="fp16", expected_name=second_name, index=0
        )
        assert_input_dtype(single_output, expected_type_str="fp16")
        verify_prediction(single_output)

    def test_color_input(self, rank4_input_model, rank3_input_model):
        """
        Convert with an RGB image input and an fp32 tensor output, then confirm
        that the rank-3 input model is rejected for image input.
        """
        rgb_layout = ct.colorlayout.RGB

        rgb_model = ct.convert(
            rank4_input_model,
            inputs=[ct.ImageType(color_layout=rgb_layout)],
            outputs=[ct.TensorType(dtype=np.float32)],
            minimum_deployment_target=ct.target.macOS13,
        )
        expected_ops = ["cast", "transpose", "add", "cast"]
        assert_ops_in_mil_program(rgb_model, expected_op_list=expected_ops)
        assert_spec_input_image_type(
            rgb_model._spec,
            expected_feature_type=proto.FeatureTypes_pb2.ImageFeatureType.RGB,
        )
        mil_program = rgb_model._mil_program
        assert_prog_input_type(mil_program, expected_dtype_str="fp32")
        assert_prog_output_type(mil_program, expected_dtype_str="fp32")
        verify_prediction(rgb_model)

        # The converter reports that an image input must have rank 4, so the
        # rank-3 model has to fail.
        with pytest.raises(ValueError, match="must have rank 4"):
            ct.convert(
                rank3_input_model,
                inputs=[ct.ImageType(color_layout=rgb_layout)],
                minimum_deployment_target=ct.target.macOS12,
            )

    def test_grayscale_input(self, rank4_input_model, rank3_input_model, rank4_grayscale_input_model):
        """
        Cover grayscale image inputs: rejected input models, GRAYSCALE, and
        GRAYSCALE_FLOAT16 together with its deployment-target requirement.
        """
        layouts = ct.colorlayout
        image_types = proto.FeatureTypes_pb2.ImageFeatureType
        macos13 = ct.target.macOS13

        # A rank-3 input model cannot be given an image input.
        with pytest.raises(ValueError, match="must have rank 4"):
            ct.convert(
                rank3_input_model,
                inputs=[ct.ImageType(color_layout=layouts.GRAYSCALE)],
                minimum_deployment_target=macos13,
            )

        # Invalid shape: this rank-4 fixture is not accepted as a grayscale image input.
        with pytest.raises(ValueError):
            ct.convert(
                rank4_input_model,
                inputs=[ct.ImageType(color_layout=layouts.GRAYSCALE)],
                minimum_deployment_target=macos13,
            )

        # GRAYSCALE image input with an fp32 tensor output.
        gray_model = ct.convert(
            rank4_grayscale_input_model,
            inputs=[ct.ImageType(color_layout=layouts.GRAYSCALE)],
            outputs=[ct.TensorType(dtype=np.float32)],
            minimum_deployment_target=macos13,
        )
        assert_ops_in_mil_program(
            gray_model, expected_op_list=["cast", "transpose", "add", "cast"]
        )
        assert_spec_input_image_type(
            gray_model._spec, expected_feature_type=image_types.GRAYSCALE
        )
        assert_prog_input_type(gray_model._mil_program, expected_dtype_str="fp32")
        assert_prog_output_type(gray_model._mil_program, expected_dtype_str="fp32")
        verify_prediction(gray_model)

        # GRAYSCALE_FLOAT16 input is rejected both with an explicit macOS12
        # target and with no target passed at all (the original note describes
        # the latter as the neural network case).
        fp16_input_msg = "float16 dtype for inputs is only supported for deployment target >= iOS16/macOS13"
        for target_kwargs in ({"minimum_deployment_target": ct.target.macOS12}, {}):
            with pytest.raises(TypeError, match=fp16_input_msg):
                ct.convert(
                    rank4_grayscale_input_model,
                    inputs=[ct.ImageType(color_layout=layouts.GRAYSCALE_FLOAT16)],
                    **target_kwargs,
                )

        # GRAYSCALE_FLOAT16 input with an fp16 tensor output at macOS13.
        gray_fp16_model = ct.convert(
            rank4_grayscale_input_model,
            inputs=[ct.ImageType(color_layout=layouts.GRAYSCALE_FLOAT16)],
            outputs=[ct.TensorType(dtype=np.float16)],
            minimum_deployment_target=macos13,
        )
        assert_ops_in_mil_program(gray_fp16_model, expected_op_list=["transpose", "add"])
        assert_spec_input_image_type(
            gray_fp16_model._spec, expected_feature_type=image_types.GRAYSCALE_FLOAT16
        )
        assert_prog_input_type(gray_fp16_model._mil_program, expected_dtype_str="fp16")
        assert_output_dtype(gray_fp16_model, expected_type_str="fp16")
        verify_prediction(gray_fp16_model)

    def test_color_output(self, rank4_input_model, rank4_input_model_with_channel_first_output):
        """
        Cover RGB/BGR image outputs: output-shape validation, conversion with a
        macOS13 deployment target, and neural network conversion.
        """
        layouts = ct.colorlayout
        image_types = proto.FeatureTypes_pb2.ImageFeatureType

        # The image output shape must be of the form (1, 3, H, W); this fixture's is not.
        with pytest.raises(ValueError, match="Shape of the RGB/BGR image output,"):
            ct.convert(
                rank4_input_model,
                inputs=[ct.ImageType(color_layout=layouts.RGB)],
                outputs=[ct.ImageType(color_layout=layouts.RGB)],
                minimum_deployment_target=ct.target.macOS13,
            )

        # BGR image in, RGB image out, with a macOS13 deployment target.
        bgr_to_rgb = ct.convert(
            rank4_input_model_with_channel_first_output,
            inputs=[ct.ImageType(color_layout=layouts.BGR)],
            outputs=[ct.ImageType(color_layout=layouts.RGB)],
            minimum_deployment_target=ct.target.macOS13,
        )
        assert_ops_in_mil_program(bgr_to_rgb, expected_op_list=["cast", "add", "cast"])
        assert_spec_input_image_type(bgr_to_rgb._spec, expected_feature_type=image_types.BGR)
        assert_spec_output_image_type(bgr_to_rgb._spec, expected_feature_type=image_types.RGB)
        assert_prog_input_type(bgr_to_rgb._mil_program, expected_dtype_str="fp32")
        assert_prog_output_type(bgr_to_rgb._mil_program, expected_dtype_str="fp32")
        verify_prediction(bgr_to_rgb)

        # Neural network conversion: RGB image in, BGR image out.
        rgb_to_bgr = ct.convert(
            rank4_input_model_with_channel_first_output,
            inputs=[ct.ImageType(color_layout=layouts.RGB)],
            outputs=[ct.ImageType(color_layout=layouts.BGR)],
            convert_to="neuralnetwork",
        )
        assert_ops_in_mil_program(rgb_to_bgr, expected_op_list=["add"])
        assert_spec_input_image_type(rgb_to_bgr._spec, expected_feature_type=image_types.RGB)
        assert_spec_output_image_type(rgb_to_bgr._spec, expected_feature_type=image_types.BGR)
        verify_prediction(rgb_to_bgr)

    def test_grayscale_output(self, rank4_grayscale_input_model, rank4_grayscale_input_model_with_channel_first_output):
        """
        Cover GRAYSCALE and GRAYSCALE_FLOAT16 image outputs, including the
        output-shape check and the deployment-target check for fp16 outputs.
        """
        layouts = ct.colorlayout
        image_types = proto.FeatureTypes_pb2.ImageFeatureType
        channel_first_model = rank4_grayscale_input_model_with_channel_first_output

        def image_io(input_layout, output_layout):
            # Fresh ImageType objects for every conversion call.
            return {
                "inputs": [ct.ImageType(color_layout=input_layout)],
                "outputs": [ct.ImageType(color_layout=output_layout)],
            }

        # The image output shape must be of the form (1, 1, H, W); this fixture's is not.
        with pytest.raises(ValueError, match="Shape of the Grayscale image output,"):
            ct.convert(
                rank4_grayscale_input_model,
                convert_to="neuralnetwork",
                **image_io(layouts.GRAYSCALE, layouts.GRAYSCALE),
            )

        # An fp16 image output is rejected for a macOS12 deployment target.
        fp16_output_msg = "float16 dtype for outputs is only supported for deployment target >= iOS16/macOS13"
        with pytest.raises(TypeError, match=fp16_output_msg):
            ct.convert(
                channel_first_model,
                outputs=[ct.ImageType(color_layout=layouts.GRAYSCALE_FLOAT16)],
                minimum_deployment_target=ct.target.macOS12,
            )

        # GRAYSCALE in / GRAYSCALE out through the neural network backend.
        nn_model = ct.convert(
            channel_first_model,
            convert_to="neuralnetwork",
            **image_io(layouts.GRAYSCALE, layouts.GRAYSCALE),
        )
        assert_ops_in_mil_program(nn_model, expected_op_list=["add"])
        assert_spec_input_image_type(nn_model._spec, expected_feature_type=image_types.GRAYSCALE)
        assert_spec_output_image_type(nn_model._spec, expected_feature_type=image_types.GRAYSCALE)
        verify_prediction(nn_model)

        # GRAYSCALE_FLOAT16 in / GRAYSCALE_FLOAT16 out: no cast ops expected.
        fp16_model = ct.convert(
            channel_first_model,
            minimum_deployment_target=ct.target.macOS13,
            **image_io(layouts.GRAYSCALE_FLOAT16, layouts.GRAYSCALE_FLOAT16),
        )
        assert_cast_ops_count(fp16_model, expected_count=0)
        assert_spec_input_image_type(
            fp16_model._spec, expected_feature_type=image_types.GRAYSCALE_FLOAT16
        )
        assert_spec_output_image_type(
            fp16_model._spec, expected_feature_type=image_types.GRAYSCALE_FLOAT16
        )
        assert_prog_input_type(fp16_model._mil_program, expected_dtype_str="fp16")
        assert_prog_output_type(fp16_model._mil_program, expected_dtype_str="fp16")
        verify_prediction(fp16_model)

        # GRAYSCALE in / GRAYSCALE_FLOAT16 out: fp32 program input, fp16 program output.
        mixed_model = ct.convert(
            channel_first_model,
            minimum_deployment_target=ct.target.macOS13,
            **image_io(layouts.GRAYSCALE, layouts.GRAYSCALE_FLOAT16),
        )
        assert_ops_in_mil_program(mixed_model, expected_op_list=["cast", "add"])
        assert_spec_input_image_type(mixed_model._spec, expected_feature_type=image_types.GRAYSCALE)
        assert_spec_output_image_type(
            mixed_model._spec, expected_feature_type=image_types.GRAYSCALE_FLOAT16
        )
        assert_prog_input_type(mixed_model._mil_program, expected_dtype_str="fp32")
        assert_prog_output_type(mixed_model._mil_program, expected_dtype_str="fp16")
        verify_prediction(mixed_model)


    def test_linear_model(self, linear_model):
        """
        Convert the linear fixture with fp16 input and output.

        Per the original note, this exercises the fuse_matmul_weight_bias pass
        when the inputs are of type float16.
        """
        fp16_io = {
            "inputs": [ct.TensorType(dtype=np.float16)],
            "outputs": [ct.TensorType(dtype=np.float16)],
        }
        converted = ct.convert(
            linear_model,
            minimum_deployment_target=ct.target.macOS13,
            **fp16_io,
        )
        assert_input_dtype(converted, expected_type_str="fp16")
        assert_output_dtype(converted, expected_type_str="fp16")
        assert_ops_in_mil_program(converted, ["linear", "relu"])
        verify_prediction(converted)

    def test_default_input_dtype(self, int32_input_model, int32_two_input_model):
        """
        When ``dtype`` is omitted from an input ``TensorType``, the input is
        expected to keep the dtype it has in the TF model (int32 here).
        """
        target = ct.target.macOS13
        io_shape = (10, 20)

        # Case 1: single input model with no dtype specified
        single_input = ct.convert(
            int32_input_model,
            inputs=[ct.TensorType(shape=io_shape)],
            minimum_deployment_target=target,
        )
        assert_input_dtype(single_input, expected_type_str="int32")
        verify_prediction(single_input)

        # Cases 2-4: two-input model. Each row gives the extra TensorType
        # kwargs for input1 and input2, followed by the expected input dtypes.
        two_input_cases = [
            ({"dtype": np.float16}, {}, ("fp16", "int32")),  # Case 2: dtype on the first input only
            ({}, {"dtype": np.float16}, ("int32", "fp16")),  # Case 3: dtype on the second input only
            ({}, {}, ("int32", "int32")),  # Case 4: no dtype on either input
        ]
        for first_extra, second_extra, expected_dtypes in two_input_cases:
            two_input = ct.convert(
                int32_two_input_model,
                inputs=[
                    ct.TensorType(name="input1", shape=io_shape, **first_extra),
                    ct.TensorType(name="input2", shape=io_shape, **second_extra),
                ],
                minimum_deployment_target=target,
            )
            for position, dtype_str in enumerate(expected_dtypes):
                assert_input_dtype(two_input, expected_type_str=dtype_str, index=position)
            verify_prediction(two_input)


class TestiOS16DefaultIODtype:
    def test_iO16_default_fp16_input(
        self,
        float32_input_model_add_op,
        int32_input_model,
    ):
        """
        Default input dtype at iOS16 for single-input models.

        With ``minimum_deployment_target`` set to iOS16 and no compute precision
        passed, the fp32 model is expected to get an fp16 input, while the
        int32 model keeps its int32 input.
        """
        cases = [
            # (source model, whether to pass an `inputs` spec, expected input dtype)
            (float32_input_model_add_op, True, "fp16"),  # Case 1: fp32 model, shape-only input spec
            (float32_input_model_add_op, False, "fp16"),  # Case 2: fp32 model, no input spec
            (int32_input_model, True, "int32"),  # Case 3: int32 model. No change made.
        ]
        for source_model, pass_inputs, expected_dtype in cases:
            convert_kwargs = {"minimum_deployment_target": ct.target.iOS16}
            if pass_inputs:
                # Build a fresh TensorType for every conversion call.
                convert_kwargs["inputs"] = [ct.TensorType(shape=(10, 20))]
            converted = ct.convert(source_model, **convert_kwargs)
            assert_input_dtype(converted, expected_type_str=expected_dtype)
            verify_prediction(converted)

    def test_iO16_default_fp16_multiple_input(
        self,
        float32_two_input_model,
        int32_two_input_model,
        int32_float32_two_input_model,
    ):
        """
        Default input dtypes at iOS16 for two-input models.

        Each case converts a model and checks the dtype of both inputs. An
        input with a user-specified dtype is expected to keep it; an fp32 input
        without one is expected to become fp16; int32 inputs are unchanged.
        """
        input_names = ("input1", "input2")
        unspecified = {}
        as_fp32 = {"dtype": np.float32}
        as_int32 = {"dtype": np.int32}

        cases = [
            # (source model, extra TensorType kwargs per input or None to omit
            #  `inputs` entirely, expected input dtypes)
            # Case 1: fp32 two inputs model. First input dtype missing
            (float32_two_input_model, (unspecified, as_fp32), ("fp16", "fp32")),
            # Case 2: fp32 two inputs model. Second input dtype missing
            (float32_two_input_model, (as_fp32, unspecified), ("fp32", "fp16")),
            # Case 3: fp32 two inputs model. Both dtype missing
            (float32_two_input_model, (unspecified, unspecified), ("fp16", "fp16")),
            # Case 4: fp32 two inputs model. inputs not given
            (float32_two_input_model, None, ("fp16", "fp16")),
            # Case 5: fp32 two inputs model. Both dtype given
            (float32_two_input_model, (as_int32, as_fp32), ("int32", "fp32")),
            # Case 6: int32 two inputs model. Both dtype missing. No change made.
            (int32_two_input_model, (unspecified, unspecified), ("int32", "int32")),
            # Case 7: mixed dtype model with two inputs. Both dtype missing.
            # The fp32 input is cast to fp16.
            (int32_float32_two_input_model, (unspecified, unspecified), ("int32", "fp16")),
        ]

        for source_model, input_extras, expected_dtypes in cases:
            convert_kwargs = {"minimum_deployment_target": ct.target.iOS16}
            if input_extras is not None:
                # Fresh TensorType objects are built for every conversion call.
                convert_kwargs["inputs"] = [
                    ct.TensorType(name=input_name, shape=(10, 20), **extra)
                    for input_name, extra in zip(input_names, input_extras)
                ]
            converted = ct.convert(source_model, **convert_kwargs)
            for position, dtype_str in enumerate(expected_dtypes):
                assert_input_dtype(converted, expected_type_str=dtype_str, index=position)
            verify_prediction(converted)

    def test_iO16_default_fp16_output(
        self,
        float32_input_model_add_op,
        int32_input_model,
    ):
        """
        Default output dtype at iOS16 for single-output models.

        With ``minimum_deployment_target`` set to iOS16 and no compute precision
        passed, the fp32 model is expected to produce an fp16 output unless the
        user pins the output dtype; the int32 model keeps an int32 output.
        """
        def convert_ios16(source_model, **extra_kwargs):
            # Every case shares the same deployment target.
            return ct.convert(
                source_model,
                minimum_deployment_target=ct.target.iOS16,
                **extra_kwargs,
            )

        # Case 1: fp32 single output model
        fp32_default = convert_ios16(float32_input_model_add_op)
        assert_output_dtype(fp32_default, expected_type_str="fp16")
        verify_prediction(fp32_default)

        # Case 2: int32 single output model. No change made.
        int32_default = convert_ios16(int32_input_model)
        assert_output_dtype(int32_default, expected_type_str="int32")
        verify_prediction(int32_default)

        # Case 3: fp32 single output model, with dtype set by the user
        user_pinned = convert_ios16(
            float32_input_model_add_op,
            outputs=[ct.TensorType(dtype=np.float32)],
        )
        assert_output_dtype(user_pinned, expected_type_str="fp32")
        verify_prediction(user_pinned)

    def test_iO16_default_fp16_multiple_output(
        self,
        float32_two_output_model,
        int32_two_output_model,
        int32_float32_two_output_model,
    ):
        """
        Default output dtypes at iOS16 for two-output models.

        Each case converts a model with iOS16 as the deployment target and
        checks the dtype of both outputs, then runs a prediction check.
        """
        ios16 = ct.target.iOS16

        # Output names differ between the TF2 and TF1 variants of the fixtures.
        if _HAS_TF_2:
            first_name, second_name = "Identity", "Identity_1"
        else:
            first_name, second_name = "output1", "output2"

        def check_outputs(converted, first_dtype, second_dtype):
            # Same three checks after every conversion, in the same order.
            assert_output_dtype(converted, expected_type_str=first_dtype, index=0)
            assert_output_dtype(converted, expected_type_str=second_dtype, index=1)
            verify_prediction(converted)

        # Case 1: fp32 two outputs model. First output dtype missing
        check_outputs(
            ct.convert(
                float32_two_output_model,
                outputs=[
                    ct.TensorType(name=first_name),
                    ct.TensorType(name=second_name, dtype=np.float32),
                ],
                minimum_deployment_target=ios16,
            ),
            "fp16",
            "fp32",
        )

        # Case 2: fp32 two outputs model. Second output dtype missing
        check_outputs(
            ct.convert(
                float32_two_output_model,
                outputs=[
                    ct.TensorType(name=first_name, dtype=np.int32),
                    ct.TensorType(name=second_name),
                ],
                minimum_deployment_target=ios16,
            ),
            "int32",
            "fp16",
        )

        # Case 3: fp32 two outputs model. Both output dtype missing
        check_outputs(
            ct.convert(
                float32_two_output_model,
                outputs=[
                    ct.TensorType(name=first_name),
                    ct.TensorType(name=second_name),
                ],
                minimum_deployment_target=ios16,
            ),
            "fp16",
            "fp16",
        )

        # Case 4: fp32 two outputs model. outputs not set.
        check_outputs(
            ct.convert(float32_two_output_model, minimum_deployment_target=ios16),
            "fp16",
            "fp16",
        )

        # Case 5: int32 two outputs model. outputs not set. No change happens.
        check_outputs(
            ct.convert(int32_two_output_model, minimum_deployment_target=ios16),
            "int32",
            "int32",
        )

        # Case 6: int32 two outputs model. The first input is forced to fp32.
        # The first output is then inferred as fp32 as well, so it defaults
        # to fp16.
        check_outputs(
            ct.convert(
                int32_two_output_model,
                inputs=[
                    ct.TensorType(name="input1", dtype=np.float32),
                    ct.TensorType(name="input2"),
                ],
                minimum_deployment_target=ios16,
            ),
            "fp16",
            "int32",
        )

        # Case 7: int32 two outputs model. The second input is forced to fp16.
        # The second output is then expected to be fp16 as well.
        check_outputs(
            ct.convert(
                int32_two_output_model,
                inputs=[
                    ct.TensorType(name="input1"),
                    ct.TensorType(name="input2", dtype=np.float16),
                ],
                minimum_deployment_target=ios16,
            ),
            "int32",
            "fp16",
        )

        # Case 8: two outputs model with int32/fp32.
        # The fp32 output defaults to fp16, while the int32 one remains unchanged.
        check_outputs(
            ct.convert(int32_float32_two_output_model, minimum_deployment_target=ios16),
            "fp16",
            "int32",
        )

    def test_iO17_default_fp32_io(
        self,
        int32_float32_two_input_model,
        int32_float32_two_output_model,
    ):
        """
        With ``minimum_deployment_target`` set to iOS16 and the compute
        precision set to fp32, an fp32 i/o model is expected by default
        (int32 inputs and outputs stay int32).

        NOTE(review): the method name says "iO17" but the conversions below
        target iOS16 -- confirm which one is intended.
        """
        fp32_precision_kwargs = {
            "compute_precision": ct.precision.FLOAT32,
            "minimum_deployment_target": ct.target.iOS16,
        }

        # Example 1: mixed int32/fp32 inputs
        mixed_inputs = ct.convert(int32_float32_two_input_model, **fp32_precision_kwargs)
        assert_input_dtype(mixed_inputs, expected_type_str="int32", index=0)
        assert_input_dtype(mixed_inputs, expected_type_str="fp32", index=1)
        assert_output_dtype(mixed_inputs, expected_type_str="fp32", index=0)

        # Example 2: mixed fp32/int32 outputs
        mixed_outputs = ct.convert(int32_float32_two_output_model, **fp32_precision_kwargs)
        assert_input_dtype(mixed_outputs, expected_type_str="fp32", index=0)
        assert_input_dtype(mixed_outputs, expected_type_str="fp32", index=1)
        assert_output_dtype(mixed_outputs, expected_type_str="fp32", index=0)
        assert_output_dtype(mixed_outputs, expected_type_str="int32", index=1)

    def test_iO16_default_image_dtype_input(
        self,
        rank4_input_model,
        rank4_grayscale_input_model,
    ):
        """
        At iOS16 an image input is expected to stay fp32 in the MIL program,
        except for GRAYSCALE_FLOAT16, which is fp16. The program output is
        expected to be fp16 in every example.
        """
        layouts = ct.colorlayout
        image_types = proto.FeatureTypes_pb2.ImageFeatureType

        examples = [
            # (source model, input color layout, expected spec image type,
            #  expected program input dtype)
            (rank4_input_model, layouts.RGB, image_types.RGB, "fp32"),  # Example 1
            (rank4_input_model, layouts.BGR, image_types.BGR, "fp32"),  # Example 2
            (
                rank4_grayscale_input_model,
                layouts.GRAYSCALE,
                image_types.GRAYSCALE,
                "fp32",
            ),  # Example 3
            (
                rank4_grayscale_input_model,
                layouts.GRAYSCALE_FLOAT16,
                image_types.GRAYSCALE_FLOAT16,
                "fp16",
            ),  # Example 4
        ]

        for source_model, layout, spec_image_type, program_input_dtype in examples:
            converted = ct.convert(
                source_model,
                inputs=[ct.ImageType(color_layout=layout)],
                minimum_deployment_target=ct.target.iOS16,
            )
            assert_spec_input_image_type(converted._spec, expected_feature_type=spec_image_type)
            assert_prog_input_type(converted._mil_program, expected_dtype_str=program_input_dtype)
            assert_prog_output_type(converted._mil_program, expected_dtype_str="fp16")
            verify_prediction(converted)

    def test_iO16_default_image_dtype_output(
        self,
        rank4_input_model_with_channel_first_output,
        rank4_grayscale_input_model_with_channel_first_output,
    ):
        """
        Check the default program dtypes when only an image output is requested
        with an iOS16 deployment target.

        For every color layout the program input is expected to be fp16. The
        program output is expected to stay fp32 for RGB, BGR and GRAYSCALE, and
        to be fp16 only for GRAYSCALE_FLOAT16.
        """
        image_feature = proto.FeatureTypes_pb2.ImageFeatureType
        color_model = rank4_input_model_with_channel_first_output
        gray_model = rank4_grayscale_input_model_with_channel_first_output

        # Each row is one example:
        # (source model, requested output layout, expected spec image type,
        #  expected program output dtype)
        examples = (
            (color_model, ct.colorlayout.RGB, image_feature.RGB, "fp32"),
            (color_model, ct.colorlayout.BGR, image_feature.BGR, "fp32"),
            (gray_model, ct.colorlayout.GRAYSCALE, image_feature.GRAYSCALE, "fp32"),
            (
                gray_model,
                ct.colorlayout.GRAYSCALE_FLOAT16,
                image_feature.GRAYSCALE_FLOAT16,
                "fp16",
            ),
        )

        for source_model, layout, spec_image_type, output_dtype in examples:
            mlmodel = ct.convert(
                source_model,
                outputs=[ct.ImageType(color_layout=layout)],
                minimum_deployment_target=ct.target.iOS16,
            )
            program = mlmodel._mil_program
            assert_prog_input_type(program, expected_dtype_str="fp16")
            assert_prog_output_type(program, expected_dtype_str=output_dtype)
            assert_spec_output_image_type(mlmodel._spec, expected_feature_type=spec_image_type)
            verify_prediction(mlmodel)
